/**
 * 
 */
package de.uni_postdam.ling.tcl.util;

import org.apache.commons.math3.random.RandomGenerator;

/**
 * This class can be used to draw indexes from an array of weights where the probability of any
 * index being drawn is proportional to e^w where w is the weight in the array at the given
 * index.
 * 
 * @author Christoph Teichmann
 *
 */
public class ArraySampler
{
	/**
	 * Message used when the search ends on a field that cannot be a valid cumulative weight,
	 * which should only happen if the array used is not a CDF.
	 */
	private static final String	NOTHING_FOUND	= "Something went wront, have you actually passed a CDF?";

	/**
	 * Message used when an empty array is passed to the sampler, as there is no index to draw.
	 */
	private static final String	EMPTY_CWF	= "Cannot draw an index from an empty cumulative weight function.";

	/**
	 * The random number generator used to draw indexes.
	 */
	private final RandomGenerator rg;

	/**
	 * Creates a new instance that will use the given RandomGenerator in order to make its draws.
	 * 
	 * @param rg the source of the uniform doubles consumed by {@link #produceSample(double[])}
	 */
	public ArraySampler(RandomGenerator rg)
	{
		this.rg = rg;
	}

	/**
	 * Will change the given array in place in order to turn it into a cumulative weight function.
	 * 
	 * This means that all values are exponentiated (base e) and normalized, then the values are
	 * added up, i.e. the first index will hold the normalized weight of the first field. The second
	 * index will hold the sum of the first and the second field, the third will be the sum of the
	 * first, second and third field and so on.
	 * 
	 * The largest entry is subtracted from every entry before exponentiation; this does not change
	 * the normalized result, but keeps Math.exp from overflowing for large log weights.
	 * 
	 * An empty array is left untouched. If the largest entry is not finite (for example when every
	 * entry is negative infinity), the subtraction yields NaN and the resulting array is not a
	 * usable cumulative weight function.
	 * 
	 * @param logWeights the weights that will be exponentiated in order to generate the proportional
	 * probabilities.
	 */
	public void turnIntoCWF(double[] logWeights)
	{
		double max = Double.NEGATIVE_INFINITY;
		
		for(double d : logWeights)
		{
			max = Math.max(max, d);
		}
		
		/*
		 * first pass: replace every log weight by its shifted exponential and
		 * remember the total so that we can normalize afterwards
		 */
		double sum = 0.0;
		
		for(int i=0;i<logWeights.length;++i)
		{
			double weight = Math.exp(logWeights[i]-max);
			
			logWeights[i] = weight;
			sum += weight;
		}
		
		/*
		 * second pass: replace every weight by the running total of the
		 * normalized weights seen so far
		 */
		double all = 0.0;
		
		for(int i=0;i<logWeights.length;++i)
		{
			all += logWeights[i]/sum;
			logWeights[i] = all;
		}
	}

	/**
	 * Produces an index within 0, cwf.length-1 according to the given cumulative weight function as
	 * would be generated by the turnIntoCWF function. If the given array does not behave like a cwf,
	 * then the behavior is undefined and an IllegalStateException may be thrown.
	 * 
	 * A single double d is drawn from the random number generator and the smallest index i with
	 * cwf[i] &gt;= d is returned. If d is larger than every field, which can happen when rounding
	 * leaves the last field marginally below 1.0, then the last index is returned.
	 * 
	 * The sampling algorithm is a binary search and produces a sample in no more than
	 * O(log(cwf.length)) meaning it can still be applied to very large arrays.
	 * 
	 * @param cwf the cumulative weight function to draw from, must not be empty
	 * @return the drawn index
	 * @throws IllegalArgumentException if cwf is empty
	 * @throws IllegalStateException if the search ends on a NaN field, i.e. cwf is not a CDF
	 */
	public int produceSample(double[] cwf)
	{
		if(cwf.length == 0)
		{
			throw new IllegalArgumentException(EMPTY_CWF);
		}
		
		double d = this.rg.nextDouble();
		
		/*
		 * the main idea behind the sampler is a binary search for the first field
		 * in the cwf whose value is at least as large as the drawn double; the
		 * result is always within left and right (both inclusive), which also means
		 * that a draw beyond the last field ends up at the last index
		 */
		int left = 0;
		int right = cwf.length-1;
		
		while(left < right)
		{
			int middle = left + ((right-left)/2);
			
			if(cwf[middle] < d)
			{
				/*
				 * this field is too small, so the result must be strictly to the
				 * right of it; moving past middle guarantees progress
				 */
				left = middle+1;
			}
			else
			{
				/*
				 * this field is large enough, so it is either the result or the
				 * result is to the left of it
				 */
				right = middle;
			}
		}
		
		/*
		 * NaN never compares as smaller than the draw, so a NaN field can end the
		 * search; such an array is not a CDF
		 */
		if(Double.isNaN(cwf[left]))
		{
			throw new IllegalStateException(NOTHING_FOUND);
		}
		
		return left;
	}
}
